{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "118UKH5bWCGa"
      },
      "source": [
        "# DALL·E mini - Inference pipeline\n",
        "\n",
        "*Generate images from a text prompt*\n",
        "\n",
        "<img src=\"https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true\" width=\"200\">\n",
        "\n",
        "This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
        "\n",
        "Just want to play? Use [the demo](https://huggingface.co/spaces/flax-community/dalle-mini).\n",
        "\n",
        "For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dS8LbaonYm3a"
      },
      "source": [
        "## 🛠️ Installation and set-up"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uzjAM2GBYpZX"
      },
      "outputs": [],
      "source": [
        "# Install required libraries\n",
        "!pip install -q git+https://github.com/huggingface/transformers.git\n",
        "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
        "!pip install -q git+https://github.com/borisdayma/dalle-mini.git"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ozHzTkyv8cqU"
      },
      "source": [
        "We load required models:\n",
        "* dalle·mini for text to encoded images\n",
        "* VQGAN for decoding images\n",
        "* CLIP for scoring predictions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K6CxW2o42f-w"
      },
      "outputs": [],
      "source": [
        "# Model references\n",
        "\n",
        "# dalle-mini\n",
        "DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"  # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
        "DALLE_COMMIT_ID = None\n",
        "\n",
        "# VQGAN model\n",
        "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
        "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
        "\n",
        "# CLIP model\n",
        "CLIP_REPO = \"openai/clip-vit-large-patch14\"\n",
        "CLIP_COMMIT_ID = None"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "If your hardware can handle it, you can use dalle-mega instead of dalle-mini.\n",
        "\n",
        "**Note: on free Colab, you will most likely not be able to load dalle-mega so don't run the below cell to use dalle-mini instead.**"
      ],
      "metadata": {
        "id": "Jy01V3OltG5q"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# dalle-mega\n",
        "DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\""
      ],
      "metadata": {
        "id": "VRRX_g9vtMZG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Yv-aR3t4Oe5v"
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "# check how many devices are available\n",
        "jax.local_device_count()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "92zYmvsQ38vL"
      },
      "outputs": [],
      "source": [
        "# Load models & tokenizer\n",
        "from dalle_mini import DalleBart, DalleBartProcessor\n",
        "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
        "from transformers import CLIPProcessor, FlaxCLIPModel\n",
        "\n",
        "# Load dalle-mini\n",
        "model, params = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False)\n",
        "\n",
        "# Load VQGAN\n",
        "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
        "\n",
        "# Load CLIP\n",
        "clip = FlaxCLIPModel.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
        "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o_vH2X1tDtzA"
      },
      "source": [
        "Model parameters are replicated on each device for faster inference."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wtvLoM48EeVw"
      },
      "outputs": [],
      "source": [
        "from flax.jax_utils import replicate\n",
        "\n",
        "params = replicate(params)\n",
        "vqgan._params = replicate(vqgan.params)\n",
        "clip._params = replicate(clip.params)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0A9AHQIgZ_qw"
      },
      "source": [
        "Model functions are compiled and parallelized to take advantage of multiple devices."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sOtoOmYsSYPz"
      },
      "outputs": [],
      "source": [
        "from functools import partial\n",
        "\n",
        "# model inference\n",
        "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
        "def p_generate(\n",
        "    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
        "):\n",
        "    return model.generate(\n",
        "        **tokenized_prompt,\n",
        "        prng_key=key,\n",
        "        params=params,\n",
        "        top_k=top_k,\n",
        "        top_p=top_p,\n",
        "        temperature=temperature,\n",
        "        condition_scale=condition_scale,\n",
        "    )\n",
        "\n",
        "\n",
        "# decode image\n",
        "@partial(jax.pmap, axis_name=\"batch\")\n",
        "def p_decode(indices, params):\n",
        "    return vqgan.decode_code(indices, params=params)\n",
        "\n",
        "\n",
        "# score images\n",
        "@partial(jax.pmap, axis_name=\"batch\")\n",
        "def p_clip(inputs, params):\n",
        "    logits = clip(params=params, **inputs).logits_per_image\n",
        "    return logits"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HmVN6IBwapBA"
      },
      "source": [
        "Keys are passed to the model on each device to generate unique inference per device."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4CTXmlUkThhX"
      },
      "outputs": [],
      "source": [
        "import random\n",
        "\n",
        "# create a random key\n",
        "seed = random.randint(0, 2**32 - 1)\n",
        "key = jax.random.PRNGKey(seed)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BrnVyCo81pij"
      },
      "source": [
        "## 🖍 Text Prompt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rsmj0Aj5OQox"
      },
      "source": [
        "Our model requires processing prompts."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YjjhUychOVxm"
      },
      "outputs": [],
      "source": [
        "from dalle_mini import DalleBartProcessor\n",
        "\n",
        "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BQ7fymSPyvF_"
      },
      "source": [
        "Let's define a text prompt."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x_0vI9ge1oKr"
      },
      "outputs": [],
      "source": [
        "prompt = \"sunset over a lake in the mountains\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VKjEZGjtO49k"
      },
      "outputs": [],
      "source": [
        "tokenized_prompt = processor([prompt])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-CEJBnuJOe5z"
      },
      "source": [
        "Finally we replicate it onto each device."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lQePgju5Oe5z"
      },
      "outputs": [],
      "source": [
        "tokenized_prompt = replicate(tokenized_prompt)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "phQ9bhjRkgAZ"
      },
      "source": [
        "## 🎨 Generate images\n",
        "\n",
        "We generate images using dalle-mini model and decode them with the VQGAN."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "d0wVkXpKqnHA"
      },
      "outputs": [],
      "source": [
        "# number of predictions\n",
        "n_predictions = 16\n",
        "\n",
        "# We can customize top_k/top_p used for generating samples\n",
        "gen_top_k = None\n",
        "gen_top_p = None\n",
        "temperature = None\n",
        "cond_scale = 3.0"
      ]
    },
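    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `cond_scale` value above is passed to `model.generate` as `condition_scale`. As a rough sketch, assuming the model applies classifier-free-guidance-style conditioning (an assumption about dalle-mini internals, not something this notebook shows), the logits from a prompt-conditioned pass and an unconditioned pass would be combined as `uncond + scale * (cond - uncond)`. The helper `apply_condition_scale` and the toy logits below are hypothetical and only illustrate the arithmetic: a scale of 1.0 returns the conditioned logits unchanged, while larger values push predictions further towards the prompt."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical helper: illustrates the guidance arithmetic only,\n",
        "# it is not part of the dalle-mini API.\n",
        "def apply_condition_scale(cond_logits, uncond_logits, scale):\n",
        "    # Start from the unconditioned logits and move towards the\n",
        "    # conditioned ones, overshooting when scale > 1.\n",
        "    return [u + scale * (c - u) for c, u in zip(cond_logits, uncond_logits)]\n",
        "\n",
        "\n",
        "cond = [2.0, 0.5, -1.0]  # toy logits computed with the prompt\n",
        "uncond = [1.0, 1.0, 0.0]  # toy logits computed without the prompt\n",
        "\n",
        "print(apply_condition_scale(cond, uncond, 1.0))\n",
        "print(apply_condition_scale(cond, uncond, 3.0))"
      ]
    },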
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SDjEx9JxR3v8"
      },
      "outputs": [],
      "source": [
        "from flax.training.common_utils import shard_prng_key\n",
        "import numpy as np\n",
        "from PIL import Image\n",
        "from tqdm.notebook import trange\n",
        "\n",
        "# generate images\n",
        "images = []\n",
        "for i in trange(max(n_predictions // jax.device_count(), 1)):\n",
        "    # get a new key\n",
        "    key, subkey = jax.random.split(key)\n",
        "    # generate images\n",
        "    encoded_images = p_generate(\n",
        "        tokenized_prompt,\n",
        "        shard_prng_key(subkey),\n",
        "        params,\n",
        "        gen_top_k,\n",
        "        gen_top_p,\n",
        "        temperature,\n",
        "        cond_scale,\n",
        "    )\n",
        "    # remove BOS\n",
        "    encoded_images = encoded_images.sequences[..., 1:]\n",
        "    # decode images\n",
        "    decoded_images = p_decode(encoded_images, vqgan.params)\n",
        "    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
        "    for img in decoded_images:\n",
        "        images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
      ]
    },
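    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Note that the loop above runs `max(n_predictions // jax.device_count(), 1)` iterations and each iteration yields one image per device, so the number of generated images is not always exactly `n_predictions`. The hypothetical helper below (not part of the pipeline) reproduces that arithmetic for a few illustrative device counts."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical helper mirroring the loop bounds of the generation cell.\n",
        "def images_generated(n_predictions, n_devices):\n",
        "    # At least one iteration always runs, even with more devices than predictions.\n",
        "    n_iterations = max(n_predictions // n_devices, 1)\n",
        "    # Every iteration produces one image on each device.\n",
        "    return n_iterations * n_devices\n",
        "\n",
        "\n",
        "for n_devices in (1, 3, 8, 32):\n",
        "    print(n_devices, images_generated(16, n_devices))"
      ]
    },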
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tw02wG9zGmyB"
      },
      "source": [
        "Let's calculate their score with CLIP."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FoLXpjCmGpju"
      },
      "outputs": [],
      "source": [
        "from flax.training.common_utils import shard\n",
        "\n",
        "# get clip scores\n",
        "clip_inputs = clip_processor(\n",
        "    text=[prompt] * jax.device_count(),\n",
        "    images=images,\n",
        "    return_tensors=\"np\",\n",
        "    padding=\"max_length\",\n",
        "    max_length=77,\n",
        "    truncation=True,\n",
        ").data\n",
        "logits = p_clip(shard(clip_inputs), clip.params)\n",
        "logits = logits.squeeze().flatten()"
      ]
    },
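    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see why the prompt is repeated once per device, it helps to follow the array shapes. `shard` reshapes the leading axis of each input from `N` to `(n_devices, N // n_devices)`. The sketch below reproduces that shape arithmetic in plain Python; `sharded_shape` is a hypothetical helper, and the 8 devices, 16 images and 224x224 CLIP input resolution are illustrative assumptions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical helper: mirrors x.reshape((n_devices, -1) + x.shape[1:])\n",
        "# for a leading axis that is divisible by the number of devices.\n",
        "def sharded_shape(shape, n_devices):\n",
        "    assert shape[0] % n_devices == 0\n",
        "    return (n_devices, shape[0] // n_devices) + tuple(shape[1:])\n",
        "\n",
        "\n",
        "n_devices, n_images = 8, 16  # illustrative values\n",
        "\n",
        "# Images are split evenly across devices.\n",
        "print(sharded_shape((n_images, 3, 224, 224), n_devices))\n",
        "# One tokenized copy of the prompt ends up on each device.\n",
        "print(sharded_shape((n_devices, 77), n_devices))"
      ]
    },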
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4AAWRm70LgED"
      },
      "source": [
        "Let's display images ranked by CLIP score."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zsgxxubLLkIu"
      },
      "outputs": [],
      "source": [
        "print(f\"Prompt: {prompt}\\n\")\n",
        "for idx in logits.argsort()[::-1]:\n",
        "    display(images[idx])\n",
        "    print(f\"Score: {logits[idx]:.2f}\\n\")"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "machine_shape": "hm",
      "name": "DALL·E mini - Inference pipeline.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}